# Device / Driver
Rust abstractions

### Kangrejos '24

Danilo Krummrich, Red Hat

---

## Motivation

- patches sent in the context of the Nova (stub) driver
- reference for how a (DRM) Rust Linux driver should look
- prepare the way for other Rust drivers to go upstream
  - other drivers that require those abstractions (rVKMS, rNVME, Asahi, cpufreq-dt)
- provide Rust driver infrastructure that:
  - integrates well with the kernel, but takes advantage of Rust's capabilities
  - does not require unsafe code or unsafe APIs
  - is as simple to use as the C APIs

---

## Context

- mailing list discussion with Greg KH on the device / driver, PCI and I/O patch series

```text
rust: pci: implement I/O mappable `pci::Bar`
rust: pci: add basic PCI device / driver abstractions
rust: add devres abstraction
rust: add `io::Io` base type
rust: add `dev_*` print macros.
rust: add `Revocable` type
rust: add rcu abstraction
rust: implement `IdArray`, `IdTable` and `RawDeviceId`
rust: implement generic driver registration
rust: pass module name to `Module::init`
--
rust: introduce `InPlaceModule`
rust: init: introduce `Opaque::try_ffi_init`
```

- some misunderstandings about what the abstraction types represent
- Greg's last proposal is to implement driver registration in C

---

## Basis for discussion

- implement a [sample PCI driver](https://git.kernel.org/pub/scm/linux/kernel/git/dakr/linux.git/log/?h=kangrejos) for QEMU's "pci-testdev" device
  - `samples/rust/rust_pci_driver/[driver.rs, mod.{c, rs}]`
- start with the driver registration in a C module
- convert the C code into Rust
- incrementally add abstractions to transform unsafe driver code into safe driver code
## Driver registration in C Module

---

`samples/rust/rust_pci_sample/mod.c`

```C [8-9|22-23]
// SPDX-License-Identifier: GPL-2.0

#include <linux/module.h>
#include <linux/pci.h>

#define PCI_DEVICE_ID_REDHAT_QEMU_PCI_TESTDEV 0x0005

extern int rust_pci_driver_probe(struct pci_dev *dev, const struct pci_device_id *id);
extern void rust_pci_driver_remove(struct pci_dev *dev);

static const struct pci_device_id rust_pci_driver_ids[] = {
	{
		PCI_DEVICE(PCI_VENDOR_ID_REDHAT,
			   PCI_DEVICE_ID_REDHAT_QEMU_PCI_TESTDEV)
	},
	{}
};

static struct pci_driver rust_pci_driver = {
	.name = "rust_pci_driver_sample",
	.id_table = rust_pci_driver_ids,
	.probe = rust_pci_driver_probe,
	.remove = rust_pci_driver_remove,
};

static int __init rust_pci_driver_init(void)
{
	return pci_register_driver(&rust_pci_driver);
}

static void __exit rust_pci_driver_exit(void)
{
	pci_unregister_driver(&rust_pci_driver);
}

module_init(rust_pci_driver_init);
module_exit(rust_pci_driver_exit);

MODULE_AUTHOR("Danilo Krummrich");
MODULE_DESCRIPTION("Rust PCI driver sample");
MODULE_LICENSE("GPL v2");
```

---

`samples/rust/rust_pci_sample/driver.rs`

```Rust [7, 14, 21|10,11,12,20]
// SPDX-License-Identifier: GPL-2.0

//! Rust PCI driver sample

use kernel::{bindings, prelude::*};

const __LOG_PREFIX: &[u8] = b"rust_pci_driver_sample\0";

#[no_mangle]
unsafe extern "C" fn rust_pci_driver_probe(
    _pdev: *mut bindings::pci_dev,
    _ent: *const bindings::pci_device_id,
) -> core::ffi::c_int {
    pr_info!("Probe Rust PCI driver sample.\n");

    0
}

#[no_mangle]
unsafe extern "C" fn rust_pci_driver_remove(_pdev: *mut bindings::pci_dev) {
    pr_info!("Remove Rust PCI driver sample.\n");
}
```

---

## Issues

- drivers are required to define `__LOG_PREFIX` manually
- unsafe driver entry points (`probe()`, `remove()`, `suspend()`, `resume()`, etc.)
- drivers require messy workarounds to bind object lifetimes to the driver and module lifetimes
  - more complex drivers often create objects in the module scope (e.g. `kmem_cache`, debug structures, etc.)
  - Rust code typically binds object lifetimes to the scope the objects live in
  - Rust kernel module abstractions already provide such a scope with their object representation
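To make the last point concrete: when module-scoped objects are fields of the module object, Rust ties their teardown to the module's lifetime, and a failing `init()` automatically cleans up whatever was already created. The following is a minimal userspace sketch under that assumption; `Resource`, `Module::init` and the event log are hypothetical stand-ins, not kernel APIs.

```rust
use std::cell::RefCell;
use std::rc::Rc;

/// Shared event log so the example can observe creation and teardown order.
pub type Log = Rc<RefCell<Vec<String>>>;

/// Hypothetical stand-in for a module-scoped kernel object such as a `kmem_cache`.
pub struct Resource {
    name: &'static str,
    log: Log,
}

impl Resource {
    /// Creates the resource; `fail` simulates an allocation error.
    pub fn new(name: &'static str, log: &Log, fail: bool) -> Result<Self, &'static str> {
        if fail {
            return Err("ENOMEM");
        }
        log.borrow_mut().push(format!("create {name}"));
        Ok(Self { name, log: Rc::clone(log) })
    }
}

impl Drop for Resource {
    fn drop(&mut self) {
        self.log.borrow_mut().push(format!("destroy {}", self.name));
    }
}

/// Hypothetical module object: whatever it owns lives exactly as long as the module.
pub struct Module {
    _cache: Resource,
    _debugfs: Resource,
}

impl Module {
    /// Shaped like a module `init()`: an early `?` return drops whatever was already created.
    pub fn init(log: &Log, fail_debugfs: bool) -> Result<Self, &'static str> {
        let cache = Resource::new("cache", log, false)?;
        let debugfs = Resource::new("debugfs", log, fail_debugfs)?;
        Ok(Self { _cache: cache, _debugfs: debugfs })
    }
}
```

When the module object is dropped, Rust drops its fields in declaration order, so the teardown order follows from how the fields are declared rather than from a hand-written exit function.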
## Driver registration in Rust Module

---

`samples/rust/rust_pci_sample/mod.rs`

```Rust [29-35|45]
// SPDX-License-Identifier: GPL-2.0

//! Rust PCI driver sample.

mod driver;

use core::ptr;
use kernel::{bindings, error::to_result, prelude::*};

module! {
    type: Module,
    name: "rust_pci_driver_sample",
    author: "Danilo Krummrich",
    description: "Rust PCI driver sample",
    license: "GPL",
}

struct Module;

impl kernel::Module for Module {
    fn init(name: &'static CStr, module: &'static ThisModule) -> Result<Self> {
        // SAFETY: `driver::DRIVER` is a valid `struct pci_driver`; `ThisModule`
        // is equivalent to C's `THIS_MODULE` and hence valid for
        // `__pci_register_driver`. `name` is passed as `NULL` terminated C
        // string.
        //
        // Returns zero when the driver was registered successfully, a non-zero
        // error code otherwise, which is handled by `to_result`.
        to_result(unsafe {
            bindings::__pci_register_driver(
                ptr::addr_of_mut!(driver::DRIVER),
                module.as_ptr(),
                name.as_char_ptr(),
            )
        })?;

        Ok(Module)
    }
}

impl Drop for Module {
    fn drop(&mut self) {
        // SAFETY: `Module::drop` is only ever called when `driver::DRIVER` was
        // registered successfully.
        unsafe { bindings::pci_unregister_driver(ptr::addr_of_mut!(driver::DRIVER)) };
    }
}
```

---

`samples/rust/rust_pci_sample/driver.rs`

```Rust [17-27|29-42]
// SPDX-License-Identifier: GPL-2.0

//! Rust PCI driver sample

use kernel::{bindings, c_str, prelude::*};

const PCI_DEVICE_ID_REDHAT_QEMU_PCI_TESTDEV: u32 = 0x0005;

pub(crate) static mut DRIVER: bindings::pci_driver = Driver::driver();

struct Driver;

impl Driver {
    const IDS: usize = 2;
    const __ID_TABLE: [bindings::pci_device_id; Self::IDS] = Self::id_table();

    const fn driver() -> bindings::pci_driver {
        // SAFETY: `bindings::pci_driver` is valid to be zero initialized.
        let mut drv: bindings::pci_driver = unsafe { core::mem::zeroed() };

        drv.name = c_str!("rust_pci_driver_sample").as_char_ptr();
        drv.id_table = Self::__ID_TABLE.as_ptr();
        drv.probe = Some(Self::probe);
        drv.remove = Some(Self::remove);

        drv
    }

    const fn id_table() -> [bindings::pci_device_id; Self::IDS] {
        // SAFETY: `bindings::pci_device_id` is valid to be zero initialized.
        let mut id: bindings::pci_device_id = unsafe { core::mem::zeroed() };

        id.vendor = bindings::PCI_VENDOR_ID_REDHAT;
        id.device = PCI_DEVICE_ID_REDHAT_QEMU_PCI_TESTDEV;
        id.subvendor = bindings::PCI_ANY_ID as u32;
        id.subdevice = bindings::PCI_ANY_ID as u32;

        // SAFETY: `bindings::pci_device_id` is valid to be zero initialized.
        let sentinel: bindings::pci_device_id = unsafe { core::mem::zeroed() };

        [id, sentinel]
    }

    extern "C" fn probe(
        _pdev: *mut bindings::pci_dev,
        _ent: *const bindings::pci_device_id,
    ) -> core::ffi::c_int {
        pr_info!("Probe Rust PCI driver sample.\n");

        0
    }

    extern "C" fn remove(_pdev: *mut bindings::pci_dev) {
        pr_info!("Remove Rust PCI driver sample.\n");
    }
}
```

---

## Issues

- drivers need unsafe function calls (`__pci_register_driver()`, `pci_unregister_driver()`)
- drivers can forget to unregister the driver structure
- we have to juggle the device ID table and potential driver data as raw pointers
- drivers must remember to add the sentinel to the device ID table
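The sentinel requirement exists because the C side never receives the length of the ID table: it scans entries until it finds an all-zero one. The following minimal userspace sketch contrasts that scan with a `#[repr(C)]` wrapper that appends the sentinel itself, similar in spirit to the `IdArray` type from the patch series. `RawId`, `c_match` and this `IdArray` are simplified, hypothetical stand-ins, not the kernel definitions.

```rust
/// Simplified userspace stand-in for C's `struct pci_device_id` (the real struct has more fields).
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct RawId {
    pub vendor: u32,
    pub device: u32,
}

const SENTINEL: RawId = RawId { vendor: 0, device: 0 };

/// Roughly what a C-side matcher does: scan entries until the all-zero sentinel.
///
/// # Safety
///
/// `table` must point to an array that is terminated by an all-zero entry.
pub unsafe fn c_match(table: *const RawId, vendor: u32, device: u32) -> Option<usize> {
    let mut i = 0;
    loop {
        // SAFETY: the caller guarantees a sentinel, so the scan stops before running off the end.
        let id = unsafe { *table.add(i) };
        if id == SENTINEL {
            return None;
        }
        if id.vendor == vendor && id.device == device {
            return Some(i);
        }
        i += 1;
    }
}

/// An ID array that always carries its sentinel, so a driver cannot forget it.
#[repr(C)]
#[allow(dead_code)] // the fields are only read through the raw pointer handed to `c_match`
pub struct IdArray<const N: usize> {
    ids: [RawId; N],
    sentinel: RawId,
}

impl<const N: usize> IdArray<N> {
    pub const fn new(ids: [RawId; N]) -> Self {
        Self { ids, sentinel: SENTINEL }
    }

    /// Pointer to the first entry, derived from the whole struct so that reading the
    /// trailing sentinel through it stays within bounds.
    pub fn as_ptr(&self) -> *const RawId {
        (self as *const Self).cast::<RawId>()
    }
}
```

Because `#[repr(C)]` places `sentinel` directly after `ids`, the C-style scan always terminates, and the driver only ever lists its real IDs.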
## PCI-specific Registration structure

---

`rust/kernel/pci.rs`

```Rust [17-33|37-54|60-63|70-95|101-105]
// SPDX-License-Identifier: GPL-2.0

//! Wrappers for the PCI subsystem
//!
//! C header: [`include/linux/pci.h`](srctree/include/linux/pci.h)

use core::cell::UnsafeCell;
use core::marker::PhantomData;
use kernel::{
    alloc::flags::*,
    bindings,
    error::{from_result, to_result},
    prelude::*,
};

/// Drivers must implement this trait to register a PCI driver.
pub trait Driver {
    /// Pointer to the PCI driver's ID table.
    ///
    /// This is highly unsafe, since we have to trust the driver that the provided pointer is
    /// valid.
    const ID_TABLE: *const bindings::pci_device_id;

    /// PCI driver probe.
    ///
    /// Called when a PCI device is matched against a PCI driver.
    fn probe(pdev: *mut bindings::pci_dev) -> Result;

    /// PCI driver remove.
    ///
    /// Called when the PCI device is unbound.
    fn remove(pdev: *mut bindings::pci_dev);
}

struct Adapter<T: Driver>(PhantomData<T>);

impl<T> Adapter<T>
where
    T: Driver,
{
    extern "C" fn probe(
        pdev: *mut bindings::pci_dev,
        _ent: *const bindings::pci_device_id,
    ) -> core::ffi::c_int {
        from_result(|| {
            T::probe(pdev)?;
            Ok(0)
        })
    }

    extern "C" fn remove(pdev: *mut bindings::pci_dev) {
        T::remove(pdev);
    }
}

/// Registration structure for a PCI driver.
///
/// The existence of an instance of this structure implies that the corresponding PCI driver is
/// currently registered.
pub struct Registration<T: Driver> {
    driver: Pin<KBox<UnsafeCell<bindings::pci_driver>>>,
    _p: PhantomData<T>,
}

impl<T> Registration<T>
where
    T: Driver,
{
    /// Register a new PCI driver from `T: Driver`.
    pub fn new(name: &'static CStr, module: &'static ThisModule) -> Result<Self> {
        let mut driver = KBox::pin(UnsafeCell::new(bindings::pci_driver::default()), GFP_KERNEL)?;

        // Abuse that `bindings::pci_driver` is `Unpin`.
        let inner = driver.get_mut();
        inner.name = name.as_char_ptr();
        inner.probe = Some(Adapter::<T>::probe);
        inner.remove = Some(Adapter::<T>::remove);
        inner.id_table = T::ID_TABLE;

        // SAFETY: `driver` is a valid `struct pci_driver`; `ThisModule` is equivalent to
        // C's `THIS_MODULE` and hence valid for `__pci_register_driver`. `name` is passed as `NULL`
        // terminated C string.
        //
        // Returns zero when the driver was registered successfully, a non-zero error code
        // otherwise, which is handled by `to_result`.
        to_result(unsafe {
            bindings::__pci_register_driver(driver.get(), module.as_ptr(), name.as_char_ptr())
        })?;

        Ok(Self {
            driver,
            _p: PhantomData::<T>,
        })
    }
}

impl<T> Drop for Registration<T>
where
    T: Driver,
{
    fn drop(&mut self) {
        // SAFETY: `Registration::drop` is only ever called when `self.driver` was registered
        // successfully.
        unsafe { bindings::pci_unregister_driver(self.driver.get()) };
    }
}

// SAFETY: `Registration` has no fields or methods accessible via `&Registration`, so it is safe to
// share references to it with multiple threads as nothing can be done.
unsafe impl<T> Sync for Registration<T> where T: Driver {}

// SAFETY: Both registration and unregistration are implemented in C and safe to be performed from
// any thread, so `Registration` is `Send`.
unsafe impl<T> Send for Registration<T> where T: Driver {}
```

---

`samples/rust/rust_pci_driver/mod.rs`

```Rust [17-27]
// SPDX-License-Identifier: GPL-2.0

//! Rust PCI driver sample.

mod driver;

use kernel::{pci, prelude::*};

module! {
    type: Module,
    name: "rust_pci_driver_sample",
    author: "Danilo Krummrich",
    description: "Rust PCI driver sample",
    license: "GPL",
}

struct Module {
    _reg: pci::Registration<driver::Driver>,
}

impl kernel::Module for Module {
    fn init(name: &'static CStr, module: &'static ThisModule) -> Result<Self> {
        Ok(Module {
            _reg: pci::Registration::new(name, module)?,
        })
    }
}
```

---

`samples/rust/rust_pci_driver/driver.rs`

```Rust [31-43]
// SPDX-License-Identifier: GPL-2.0

//! Rust PCI driver sample

use kernel::{bindings, pci, prelude::*};

const PCI_DEVICE_ID_REDHAT_QEMU_PCI_TESTDEV: u32 = 0x0005;

pub(crate) struct Driver;

impl Driver {
    const IDS: usize = 2;
    const __ID_TABLE: [bindings::pci_device_id; Self::IDS] = Self::id_table();

    const fn id_table() -> [bindings::pci_device_id; Self::IDS] {
        // SAFETY: `bindings::pci_device_id` is valid to be zero initialized.
        let mut id: bindings::pci_device_id = unsafe { core::mem::zeroed() };

        id.vendor = bindings::PCI_VENDOR_ID_REDHAT;
        id.device = PCI_DEVICE_ID_REDHAT_QEMU_PCI_TESTDEV;
        id.subvendor = bindings::PCI_ANY_ID as u32;
        id.subdevice = bindings::PCI_ANY_ID as u32;

        // SAFETY: `bindings::pci_device_id` is valid to be zero initialized.
        let sentinel: bindings::pci_device_id = unsafe { core::mem::zeroed() };

        [id, sentinel]
    }
}

impl pci::Driver for Driver {
    const ID_TABLE: *const bindings::pci_device_id = Self::__ID_TABLE.as_ptr();

    fn probe(_pdev: *mut bindings::pci_dev) -> Result {
        pr_info!("Probe Rust PCI driver sample.\n");

        Ok(())
    }

    fn remove(_pdev: *mut bindings::pci_dev) {
        pr_info!("Remove Rust PCI driver sample.\n");
    }
}
```

---

## Issues resolved

- the `Registration` type ensures that a driver is eventually unregistered, namely when the `Registration` is dropped
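Besides `Registration`, this step introduces `Adapter<T>`, which keeps raw `extern "C"` entry points out of driver code: it is a zero-sized type whose `extern "C"` functions are instantiated once per driver type and forward to safe trait methods, translating `Result` into the C return convention. Below is a minimal userspace sketch of that trampoline pattern; `RawDev`, `Error`, `ENODEV`, `ProbeFn` and this simplified `from_result` are hypothetical stand-ins (the kernel's `Error` and `from_result` differ in detail), and unlike the slide the sketch already hands the driver a reference, assuming the C core passes a valid device.

```rust
use std::ffi::c_int;
use std::marker::PhantomData;

/// Hypothetical negative-errno error type, standing in for the kernel's `Error`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Error(c_int);

/// `-ENODEV`, returned when a driver declines a device.
pub const ENODEV: Error = Error(-19);

/// Simplified `from_result`: the `Ok` value on success, the negative errno on failure.
pub fn from_result(f: impl FnOnce() -> Result<c_int, Error>) -> c_int {
    match f() {
        Ok(v) => v,
        Err(Error(errno)) => errno,
    }
}

/// Hypothetical C device structure.
#[repr(C)]
pub struct RawDev {
    pub vendor: u32,
}

/// Safe driver interface, in the spirit of the slide's `pci::Driver`.
pub trait Driver {
    fn probe(dev: &RawDev) -> Result<(), Error>;
}

/// C-style callback slot, like the `probe` member of `struct pci_driver`.
pub type ProbeFn = extern "C" fn(*mut RawDev) -> c_int;

/// Zero-sized adapter: one `extern "C"` trampoline is instantiated per driver type `T`.
pub struct Adapter<T: Driver>(PhantomData<T>);

impl<T: Driver> Adapter<T> {
    pub extern "C" fn probe(dev: *mut RawDev) -> c_int {
        // SAFETY (assumption of this sketch): the caller passes a valid, live device.
        let dev = unsafe { &*dev };
        from_result(|| {
            T::probe(dev)?;
            Ok(0)
        })
    }
}
```

Since `Adapter::<MyDriver>::probe` coerces to a plain `ProbeFn`, it can be stored in a C ops table, while the driver itself only implements a safe trait.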
## Generic Registration structure

---

`rust/kernel/driver.rs`

```Rust [16-42|49-52|59-70|77-80]
// SPDX-License-Identifier: GPL-2.0

//! Generic driver support

use core::cell::UnsafeCell;
use core::marker::PhantomData;
use kernel::{alloc::flags::*, prelude::*};

/// The [`RegistrationOps`] trait serves as generic interface for subsystems (e.g., PCI, Platform,
/// Amba, etc.) to provide the corresponding subsystem specific implementation to register /
/// unregister a driver of the particular type (`RegType`).
///
/// For instance, the PCI subsystem would set `RegType` to `bindings::pci_driver` and call
/// `bindings::__pci_register_driver` from `RegistrationOps::register` and
/// `bindings::pci_unregister_driver` from `RegistrationOps::unregister`.
pub trait RegistrationOps {
    /// The type that holds information about the registration. This is typically a struct defined
    /// by the C portion of the kernel, e.g. `bindings::pci_driver`.
    type RegType: Default;

    /// Registers a driver.
    ///
    /// # Safety
    ///
    /// `reg` must point to valid, initialised, and writable memory. It may be modified by this
    /// function to hold registration state.
    ///
    /// On success, `reg` must remain pinned and valid until the matching call to
    /// [`RegistrationOps::unregister`].
    unsafe fn register(
        reg: *mut Self::RegType,
        name: &'static CStr,
        module: &'static ThisModule,
    ) -> Result;

    /// Unregisters a driver previously registered with [`RegistrationOps::register`].
    ///
    /// # Safety
    ///
    /// `reg` must point to valid writable memory, initialised by a previous successful call to
    /// [`RegistrationOps::register`].
    unsafe fn unregister(reg: *mut Self::RegType);
}

/// Registration structure for a driver.
///
/// The existence of an instance of this structure implies that the corresponding driver is
/// currently registered.
pub struct Registration<T: RegistrationOps> {
    driver: Pin<KBox<UnsafeCell<T::RegType>>>,
    _p: PhantomData<T>,
}

impl<T> Registration<T>
where
    T: RegistrationOps,
{
    /// Register a new driver from `T::RegType`.
    pub fn new(name: &'static CStr, module: &'static ThisModule) -> Result<Self> {
        let driver = KBox::pin(UnsafeCell::new(T::RegType::default()), GFP_KERNEL)?;

        // SAFETY: `driver` has just been initialized; it's only freed after `Self::drop`, which
        // calls `T::unregister` first.
        unsafe { T::register(driver.get(), name, module) }?;

        Ok(Self {
            driver,
            _p: PhantomData::<T>,
        })
    }
}

impl<T> Drop for Registration<T>
where
    T: RegistrationOps,
{
    fn drop(&mut self) {
        // SAFETY: Only ever called if the `Registration` was created successfully.
        unsafe { T::unregister(self.driver.get()) };
    }
}

// SAFETY: `Registration` has no fields or methods accessible via `&Registration`, so it is safe to
// share references to it with multiple threads as nothing can be done.
unsafe impl<T> Sync for Registration<T> where T: RegistrationOps {}

// SAFETY: Both registration and unregistration are implemented in C and safe to be performed from
// any thread, so `Registration` is `Send`.
unsafe impl<T> Send for Registration<T> where T: RegistrationOps {}
```

---

`rust/kernel/pci.rs`

```Rust [55-90]
// SPDX-License-Identifier: GPL-2.0

//! Wrappers for the PCI subsystem
//!
//! C header: [`include/linux/pci.h`](srctree/include/linux/pci.h)

use core::marker::PhantomData;
use kernel::{
    bindings, driver,
    error::{from_result, to_result},
    prelude::*,
};

/// Drivers must implement this trait to register a PCI driver.
pub trait Driver {
    /// Pointer to the PCI driver's ID table.
    ///
    /// This is highly unsafe, since we have to trust the driver that the provided pointer is
    /// valid.
    const ID_TABLE: *const bindings::pci_device_id;

    /// PCI driver probe.
    ///
    /// Called when a PCI device is matched against a PCI driver.
    fn probe(pdev: *mut bindings::pci_dev) -> Result;

    /// PCI driver remove.
    ///
    /// Called when the PCI device is unbound.
    fn remove(pdev: *mut bindings::pci_dev);
}

/// PCI abstraction for registering PCI drivers.
pub struct Adapter<T: Driver>(PhantomData<T>);

impl<T> Adapter<T>
where
    T: Driver,
{
    extern "C" fn probe(
        pdev: *mut bindings::pci_dev,
        _ent: *const bindings::pci_device_id,
    ) -> core::ffi::c_int {
        from_result(|| {
            T::probe(pdev)?;
            Ok(0)
        })
    }

    extern "C" fn remove(pdev: *mut bindings::pci_dev) {
        T::remove(pdev);
    }
}

impl<T> driver::RegistrationOps for Adapter<T>
where
    T: Driver,
{
    type RegType = bindings::pci_driver;

    unsafe fn register(
        pdrv: *mut Self::RegType,
        name: &'static CStr,
        module: &'static ThisModule,
    ) -> Result {
        // SAFETY: By the safety requirements of this function `pdrv` is valid; we never move out
        // of `pdrv`.
        let pdrv = unsafe { &mut *pdrv };

        pdrv.name = name.as_char_ptr();
        pdrv.probe = Some(Self::probe);
        pdrv.remove = Some(Self::remove);
        pdrv.id_table = T::ID_TABLE;

        // SAFETY: `pdrv` is a valid `struct pci_driver`; `ThisModule` is equivalent to
        // C's `THIS_MODULE` and hence valid for `__pci_register_driver`. `name` is passed as `NULL`
        // terminated C string.
        //
        // Returns zero when the driver was registered successfully, a non-zero error code
        // otherwise, which is handled by `to_result`.
        to_result(unsafe {
            bindings::__pci_register_driver(pdrv, module.as_ptr(), name.as_char_ptr())
        })
    }

    unsafe fn unregister(pdrv: *mut Self::RegType) {
        // SAFETY: `pdrv` is guaranteed to be a valid `RegType`.
        unsafe { bindings::pci_unregister_driver(pdrv) }
    }
}
```

---

`samples/rust/rust_pci_driver/mod.rs`

```Rust [18,24]
// SPDX-License-Identifier: GPL-2.0

//! Rust PCI driver sample.

mod driver;

use kernel::{pci, prelude::*};

module! {
    type: Module,
    name: "rust_pci_driver_sample",
    author: "Danilo Krummrich",
    description: "Rust PCI driver sample",
    license: "GPL",
}

struct Module {
    _reg: kernel::driver::Registration<pci::Adapter<driver::Driver>>,
}

impl kernel::Module for Module {
    fn init(name: &'static CStr, module: &'static ThisModule) -> Result<Self> {
        Ok(Module {
            _reg: kernel::driver::Registration::new(name, module)?,
        })
    }
}
```

---

## Issues resolved

- generalized the `Registration` type
  - this really starts to pay off once we stop using `Box` to allocate driver structures and use static allocation through `InPlaceModule` instead
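The split between a generic `Registration<T>` and per-subsystem `RegistrationOps` can be exercised outside the kernel. This is a minimal sketch under simplifying assumptions: the trait drops the `name`/`module` parameters, uses `i32` errors, and keeps the registration structure in a heap allocation owned through a raw pointer (instead of `Pin<KBox<UnsafeCell<_>>>`) so that its address stays stable. `FakeBus`, `FakeDriver` and `active()` are hypothetical test doubles, not kernel types.

```rust
use std::cell::Cell;
use std::marker::PhantomData;

/// Simplified version of the slide's `RegistrationOps` (no `name`/`module`, errors as `i32`).
pub trait RegistrationOps {
    /// The C registration structure, e.g. `pci_driver`.
    type RegType: Default;

    /// # Safety
    ///
    /// `reg` must be valid for writes and must not move until `unregister` is called.
    unsafe fn register(reg: *mut Self::RegType) -> Result<(), i32>;

    /// # Safety
    ///
    /// `reg` must have been passed to a successful `register` call.
    unsafe fn unregister(reg: *mut Self::RegType);
}

/// Generic registration: the driver is registered exactly as long as this value exists.
pub struct Registration<T: RegistrationOps> {
    // Heap allocation owned via a raw pointer, so its address stays stable when `Self` moves.
    reg: *mut T::RegType,
    _p: PhantomData<T>,
}

impl<T: RegistrationOps> Registration<T> {
    pub fn new() -> Result<Self, i32> {
        let reg = Box::into_raw(Box::new(T::RegType::default()));
        // SAFETY: `reg` is a fresh, valid allocation that is only freed after `unregister`.
        if let Err(e) = unsafe { T::register(reg) } {
            // SAFETY: registration failed, so nothing else refers to `reg`.
            drop(unsafe { Box::from_raw(reg) });
            return Err(e);
        }
        Ok(Self { reg, _p: PhantomData })
    }
}

impl<T: RegistrationOps> Drop for Registration<T> {
    fn drop(&mut self) {
        // SAFETY: a `Registration` only exists if `register` succeeded.
        unsafe { T::unregister(self.reg) };
        // SAFETY: after unregistering, nothing else refers to the allocation.
        drop(unsafe { Box::from_raw(self.reg) });
    }
}

/// Hypothetical subsystem "C structure" used to exercise the generic code.
#[derive(Default)]
pub struct FakeDriver {
    pub registered: bool,
}

thread_local! {
    static ACTIVE: Cell<u32> = Cell::new(0);
}

/// Number of currently registered `FakeBus` drivers.
pub fn active() -> u32 {
    ACTIVE.with(|a| a.get())
}

/// Hypothetical subsystem implementing the registration hooks.
pub struct FakeBus;

impl RegistrationOps for FakeBus {
    type RegType = FakeDriver;

    unsafe fn register(reg: *mut FakeDriver) -> Result<(), i32> {
        // SAFETY: guaranteed valid by the caller.
        unsafe { (*reg).registered = true };
        ACTIVE.with(|a| a.set(a.get() + 1));
        Ok(())
    }

    unsafe fn unregister(reg: *mut FakeDriver) {
        // SAFETY: guaranteed valid by the caller.
        unsafe { (*reg).registered = false };
        ACTIVE.with(|a| a.set(a.get() - 1));
    }
}
```

All the unsafe lifetime bookkeeping lives once in the generic `Registration`; a new subsystem only supplies the two hooks.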
## PCI Device ID abstraction

---

`rust/kernel/pci.rs`

```Rust [16-29|31-62|66-69|71-91|97-105|137-145|151]
// SPDX-License-Identifier: GPL-2.0

//! Wrappers for the PCI subsystem
//!
//! C header: [`include/linux/pci.h`](srctree/include/linux/pci.h)

use core::marker::PhantomData;
use kernel::{
    bindings, driver,
    error::{from_result, to_result},
    prelude::*,
};

/// Abstraction for `bindings::pci_device_id`.
#[derive(Clone, Copy)]
pub struct DeviceId {
    /// Vendor ID
    pub vendor: u32,
    /// Device ID
    pub device: u32,
    /// Subsystem vendor ID
    pub subvendor: u32,
    /// Subsystem device ID
    pub subdevice: u32,
    /// Device class and subclass
    pub class: u32,
    /// Limit which sub-fields of the class
    pub class_mask: u32,
}

impl DeviceId {
    /// Zeroed `bindings::pci_device_id`.
    // SAFETY: The all-zero byte-pattern is valid for `bindings::pci_device_id`.
    pub const ZERO: bindings::pci_device_id = unsafe { core::mem::zeroed() };
    const PCI_ANY_ID: u32 = !0;

    /// Equivalent to the PCI_DEVICE macro.
    pub const fn new(vendor: u32, device: u32) -> Self {
        Self {
            vendor,
            device,
            subvendor: DeviceId::PCI_ANY_ID,
            subdevice: DeviceId::PCI_ANY_ID,
            class: 0,
            class_mask: 0,
        }
    }

    /// Convert `DeviceId` to raw `bindings::pci_device_id`.
    pub const fn to_rawid(&self) -> bindings::pci_device_id {
        let mut raw = Self::ZERO;

        raw.vendor = self.vendor;
        raw.device = self.device;
        raw.subvendor = self.subvendor;
        raw.subdevice = self.subdevice;
        raw.class = self.class;
        raw.class_mask = self.class_mask;

        raw
    }
}

/// A zero-terminated PCI device ID array.
#[repr(C)]
pub struct IdArray<const N: usize> {
    ids: [bindings::pci_device_id; N],
    sentinel: bindings::pci_device_id,
}

impl<const N: usize> IdArray<N> {
    /// Creates a new instance of the ID array.
    ///
    /// The contents are derived from the given identifiers.
    #[doc(hidden)]
    pub const fn new(ids: [bindings::pci_device_id; N]) -> Self {
        Self {
            ids,
            sentinel: DeviceId::ZERO,
        }
    }

    /// Returns an `IdTable` backed by `self`.
    ///
    /// This is used to essentially erase the array size.
    pub const fn as_table(&self) -> IdTable<'_> {
        IdTable {
            first: &self.ids[0],
        }
    }
}

/// A device ID table.
///
/// The table is guaranteed to be zero-terminated.
#[repr(C)]
pub struct IdTable<'a> {
    first: &'a bindings::pci_device_id,
}

impl AsRef<bindings::pci_device_id> for IdTable<'_> {
    fn as_ref(&self) -> &bindings::pci_device_id {
        self.first
    }
}

/// Counts the number of parenthesis-delimited, comma-separated items.
#[macro_export]
macro_rules! count_paren_items {
    (($($item:tt)*), $($remaining:tt)*) => { 1 + $crate::count_paren_items!($($remaining)*) };
    (($($item:tt)*)) => { 1 };
    () => { 0 };
}

#[macro_export]
#[doc(hidden)]
macro_rules! define_pci_id_array {
    ($($args:tt)*) => {{
        const fn new<const N: usize>(ids: [$crate::pci::DeviceId; N]) -> $crate::pci::IdArray<N> {
            let mut raw_ids = [$crate::pci::DeviceId::ZERO; N];
            let mut i = 0usize;

            while i < N {
                raw_ids[i] = ids[i].to_rawid();
                i += 1;
            }

            $crate::pci::IdArray::<N>::new(raw_ids)
        }

        new([ $($args)* ])
    }}
}

/// Define a const PCI device ID table.
#[macro_export]
macro_rules! define_pci_id_table {
    ([ $($args:tt)* ]) => {
        const ID_TABLE: $crate::pci::IdTable<'static> = {
            const ARRAY: $crate::pci::IdArray<{ $crate::count_paren_items!($($args)*) }> =
                $crate::define_pci_id_array!($($args)*);
            ARRAY.as_table()
        };
    };
}
pub use define_pci_id_table;

/// Drivers must implement this trait to register a PCI driver.
pub trait Driver {
    /// The table of device IDs supported by this driver.
    const ID_TABLE: IdTable<'static>;

    /// PCI driver probe.
    ///
    /// Called when a PCI device is matched against a PCI driver.
    fn probe(pdev: *mut bindings::pci_dev) -> Result;

    /// PCI driver remove.
    ///
    /// Called when the PCI device is unbound.
    fn remove(pdev: *mut bindings::pci_dev);
}

/// PCI abstraction for registering PCI drivers.
pub struct Adapter<T: Driver>(PhantomData<T>);

impl<T> Adapter<T>
where
    T: Driver,
{
    extern "C" fn probe(
        pdev: *mut bindings::pci_dev,
        _ent: *const bindings::pci_device_id,
    ) -> core::ffi::c_int {
        from_result(|| {
            T::probe(pdev)?;
            Ok(0)
        })
    }

    extern "C" fn remove(pdev: *mut bindings::pci_dev) {
        T::remove(pdev);
    }
}

impl<T> driver::RegistrationOps for Adapter<T>
where
    T: Driver,
{
    type RegType = bindings::pci_driver;

    unsafe fn register(
        pdrv: *mut Self::RegType,
        name: &'static CStr,
        module: &'static ThisModule,
    ) -> Result {
        // SAFETY: By the safety requirements of this function `pdrv` is valid; we never move out
        // of `pdrv`.
        let pdrv = unsafe { &mut *pdrv };

        pdrv.name = name.as_char_ptr();
        pdrv.probe = Some(Self::probe);
        pdrv.remove = Some(Self::remove);
        pdrv.id_table = T::ID_TABLE.as_ref();

        // SAFETY: `pdrv` is a valid `struct pci_driver`; `ThisModule` is equivalent to
        // C's `THIS_MODULE` and hence valid for `__pci_register_driver`. `name` is passed as `NULL`
        // terminated C string.
        //
        // Returns zero when the driver was registered successfully, a non-zero error code
        // otherwise, which is handled by `to_result`.
        to_result(unsafe {
            bindings::__pci_register_driver(pdrv, module.as_ptr(), name.as_char_ptr())
        })
    }

    unsafe fn unregister(pdrv: *mut Self::RegType) {
        // SAFETY: `pdrv` is guaranteed to be a valid `RegType`.
        unsafe { bindings::pci_unregister_driver(pdrv) }
    }
}
```

---

`samples/rust/rust_pci_driver/driver.rs`

```Rust [12-15]
// SPDX-License-Identifier: GPL-2.0

//! Rust PCI driver sample

use kernel::{bindings, pci, pci::define_pci_id_table, prelude::*};

const PCI_DEVICE_ID_REDHAT_QEMU_PCI_TESTDEV: u32 = 0x0005;

pub(crate) struct Driver;

impl pci::Driver for Driver {
    define_pci_id_table! {
        [ (pci::DeviceId::new(bindings::PCI_VENDOR_ID_REDHAT,
                              PCI_DEVICE_ID_REDHAT_QEMU_PCI_TESTDEV)) ]
    }

    fn probe(_pdev: *mut bindings::pci_dev) -> Result {
        pr_info!("Probe Rust PCI driver sample.\n");

        Ok(())
    }

    fn remove(_pdev: *mut bindings::pci_dev) {
        pr_info!("Remove Rust PCI driver sample.\n");
    }
}
```

---

## Issues resolved

- device ID table initialization no longer requires raw pointers on the driver side
- the sentinel is added automatically at the end of the device ID array

---

`rust/kernel/pci.rs`

```Rust [52,61|72,84|100-105|169-174|200,203,208|230-241,244]
// SPDX-License-Identifier: GPL-2.0

//! Wrappers for the PCI subsystem
//!
//! C header: [`include/linux/pci.h`](srctree/include/linux/pci.h)

use core::marker::PhantomData;
use kernel::{
    bindings, driver,
    error::{from_result, to_result},
    prelude::*,
};

/// Abstraction for `bindings::pci_device_id`.
#[derive(Clone, Copy)]
pub struct DeviceId {
    /// Vendor ID
    pub vendor: u32,
    /// Device ID
    pub device: u32,
    /// Subsystem vendor ID
    pub subvendor: u32,
    /// Subsystem device ID
    pub subdevice: u32,
    /// Device class and subclass
    pub class: u32,
    /// Limit which sub-fields of the class
    pub class_mask: u32,
}

impl DeviceId {
    /// Zeroed `bindings::pci_device_id`.
    // SAFETY: The all-zero byte-pattern is valid for `bindings::pci_device_id`.
    pub const ZERO: bindings::pci_device_id = unsafe { core::mem::zeroed() };
    const PCI_ANY_ID: u32 = !0;

    /// Equivalent to the PCI_DEVICE macro.
    pub const fn new(vendor: u32, device: u32) -> Self {
        Self {
            vendor,
            device,
            subvendor: DeviceId::PCI_ANY_ID,
            subdevice: DeviceId::PCI_ANY_ID,
            class: 0,
            class_mask: 0,
        }
    }

    /// Convert `DeviceId` to raw `bindings::pci_device_id`.
    ///
    /// `offset` is the offset of `ids[i]` to `id_infos[i]` within `IdArray`.
    pub const fn to_rawid(&self, offset: usize) -> bindings::pci_device_id {
        let mut raw = Self::ZERO;

        raw.vendor = self.vendor;
        raw.device = self.device;
        raw.subvendor = self.subvendor;
        raw.subdevice = self.subdevice;
        raw.class = self.class;
        raw.class_mask = self.class_mask;
        raw.driver_data = offset as _;

        raw
    }
}

/// A zero-terminated PCI device ID array.
#[repr(C)]
pub struct IdArray<U, const N: usize> {
    ids: [bindings::pci_device_id; N],
    sentinel: bindings::pci_device_id,
    id_infos: [Option<U>; N],
}

impl<U, const N: usize> IdArray<U, N> {
    /// Creates a new instance of the ID array.
    ///
    /// The contents are derived from the given identifiers.
    #[doc(hidden)]
    pub const fn new(ids: [bindings::pci_device_id; N], infos: [Option<U>; N]) -> Self {
        Self {
            ids,
            sentinel: DeviceId::ZERO,
            id_infos: infos,
        }
    }

    /// Returns an `IdTable` backed by `self`.
    ///
    /// This is used to essentially erase the array size.
    pub const fn as_table(&self) -> IdTable<'_, U> {
        IdTable {
            first: &self.ids[0],
            _p: PhantomData,
        }
    }

    /// Returns the offset of `ids[i]` to `id_infos[i]` within `IdArray`.
    #[doc(hidden)]
    pub const fn get_offset(index: usize) -> usize {
        let id_size = core::mem::size_of::<bindings::pci_device_id>();
        let info_size = core::mem::size_of::<Option<U>>();

        id_size * (N - index + 1) + info_size * index
    }
}

/// A device ID table.
///
/// The table is guaranteed to be zero-terminated.
#[repr(C)]
pub struct IdTable<'a, U> {
    first: &'a bindings::pci_device_id,
    _p: PhantomData<&'a U>,
}

impl<U> AsRef<bindings::pci_device_id> for IdTable<'_, U> {
    fn as_ref(&self) -> &bindings::pci_device_id {
        self.first
    }
}

/// Converts a comma-separated list of pairs into an array with the first element. That is, it
/// discards the second element of the pair.
///
/// Additionally, it automatically introduces a type if the first element is wrapped in curly
/// braces, for example, if it's `{v: 10}`, it becomes `X { v: 10 }`; this is to avoid repeating
/// the type.
#[macro_export]
macro_rules! first_item {
    ($id_type:ty, $(({$($first:tt)*}, $second:expr)),* $(,)?) => {
        {
            type IdType = $id_type;
            [$(IdType{$($first)*},)*]
        }
    };
    ($id_type:ty, $(($first:expr, $second:expr)),* $(,)?) => { [$($first,)*] };
}

/// Converts a comma-separated list of pairs into an array with the second element. That is, it
/// discards the first element of the pair.
#[macro_export]
macro_rules! second_item {
    ($(({$($first:tt)*}, $second:expr)),* $(,)?) => { [$($second,)*] };
    ($(($first:expr, $second:expr)),* $(,)?) => { [$($second,)*] };
}

/// Counts the number of parenthesis-delimited, comma-separated items.
#[macro_export]
macro_rules! count_paren_items {
    (($($item:tt)*), $($remaining:tt)*) => { 1 + $crate::count_paren_items!($($remaining)*) };
    (($($item:tt)*)) => { 1 };
    () => { 0 };
}

#[macro_export]
#[doc(hidden)]
macro_rules! define_pci_id_array {
    ($id_type:ty, $($args:tt)*) => {{
        const fn new<U, const N: usize>(
            ids: [$crate::pci::DeviceId; N],
            infos: [Option<U>; N],
        ) -> $crate::pci::IdArray<U, N> {
            let mut raw_ids = [$crate::pci::DeviceId::ZERO; N];

            // The offset of `ids[i]` to `id_infos[i]` within `IdArray`.
            let mut i = 0usize;

            while i < N {
                let offset = $crate::pci::IdArray::<U, N>::get_offset(i);
                raw_ids[i] = ids[i].to_rawid(offset);
                i += 1;
            }

            $crate::pci::IdArray::<U, N>::new(raw_ids, infos)
        }

        new($crate::first_item!($crate::pci::DeviceId, $($args)*), $crate::second_item!($($args)*))
    }}
}

/// Define a const PCI device ID table.
#[macro_export]
macro_rules! define_pci_id_table {
    ($id_type:ty, [ $($args:tt)* ]) => {
        type IdInfo = $id_type;

        const ID_TABLE: $crate::pci::IdTable<'static, $id_type> = {
            const ARRAY: $crate::pci::IdArray<$id_type, { $crate::count_paren_items!($($args)*) }> =
                $crate::define_pci_id_array!($id_type, $($args)*);

            ARRAY.as_table()
        };
    };
}
pub use define_pci_id_table;

/// Drivers must implement this trait to register a PCI driver.
pub trait Driver {
    /// The type holding information about each device ID supported by the driver.
    type IdInfo: 'static;

    /// The table of device IDs supported by this driver.
    const ID_TABLE: IdTable<'static, Self::IdInfo>;

    /// PCI driver probe.
    ///
    /// Called when a PCI device is matched against a PCI driver.
    fn probe(pdev: *mut bindings::pci_dev, id: Option<&Self::IdInfo>) -> Result;

    /// PCI driver remove.
    ///
    /// Called when the PCI device is unbound.
    fn remove(pdev: *mut bindings::pci_dev);
}

/// PCI abstraction for registering PCI drivers.
pub struct Adapter<T: Driver>(PhantomData<T>);

impl<T> Adapter<T>
where
    T: Driver,
{
    extern "C" fn probe(
        pdev: *mut bindings::pci_dev,
        id: *const bindings::pci_device_id,
    ) -> core::ffi::c_int {
        // SAFETY: `id` is a pointer within the static table, so it's always valid.
        let offset = unsafe { (*id).driver_data };

        let info = {
            // SAFETY: The offset comes from `IdArray::get_offset` (via `define_pci_id_array!`),
            // which guarantees that the resulting pointer is within the table.
            let ptr = unsafe {
                id.cast::<u8>()
                    .offset(offset as _)
                    .cast::<Option<T::IdInfo>>()
            };

            // SAFETY: Guaranteed by the preceding safety requirement.
            unsafe { (*ptr).as_ref() }
        };

        from_result(|| {
            T::probe(pdev, info)?;
            Ok(0)
        })
    }

    extern "C" fn remove(pdev: *mut bindings::pci_dev) {
        T::remove(pdev);
    }
}

impl<T> driver::RegistrationOps for Adapter<T>
where
    T: Driver,
{
    type RegType = bindings::pci_driver;

    unsafe fn register(
        pdrv: *mut Self::RegType,
        name: &'static CStr,
        module: &'static ThisModule,
    ) -> Result {
        // SAFETY: By the safety requirements of this function `pdrv` is valid; we never move out
        // of `pdrv`.
        let pdrv = unsafe { &mut *pdrv };

        pdrv.name = name.as_char_ptr();
        pdrv.probe = Some(Self::probe);
        pdrv.remove = Some(Self::remove);
        pdrv.id_table = T::ID_TABLE.as_ref();

        // SAFETY: `pdrv` is a valid `struct pci_driver`; `ThisModule` is equivalent to
        // C's `THIS_MODULE` and hence valid for `__pci_register_driver`. `name` is passed as `NULL`
        // terminated C string.
        //
        // Returns zero when the driver was registered successfully, a non-zero error code
        // otherwise, which is handled by `to_result`.
        to_result(unsafe {
            bindings::__pci_register_driver(pdrv, module.as_ptr(), name.as_char_ptr())
        })
    }

    unsafe fn unregister(pdrv: *mut Self::RegType) {
        // SAFETY: `pdrv` is guaranteed to be a valid `RegType`.
        unsafe { bindings::pci_unregister_driver(pdrv) }
    }
}
```

---

`samples/rust/rust_pci_driver/driver.rs`

```Rust [12-16|18,20]
// SPDX-License-Identifier: GPL-2.0

//! Rust PCI driver sample

use kernel::{bindings, pci, pci::define_pci_id_table, prelude::*};

const PCI_DEVICE_ID_REDHAT_QEMU_PCI_TESTDEV: u32 = 0x0005;

pub(crate) struct Driver;

impl pci::Driver for Driver {
    define_pci_id_table! {
        (),
        [ (pci::DeviceId::new(bindings::PCI_VENDOR_ID_REDHAT,
                              PCI_DEVICE_ID_REDHAT_QEMU_PCI_TESTDEV), None) ]
    }

    fn probe(_pdev: *mut bindings::pci_dev, id: Option<&Self::IdInfo>) -> Result {
        pr_info!("Probe Rust PCI driver sample.\n");
        pr_info!("Info: {:?}\n", id);

        Ok(())
    }

    fn remove(_pdev: *mut bindings::pci_dev) {
        pr_info!("Remove Rust PCI driver sample.\n");
    }
}
```

---

## Issues resolved

- support per-device-ID driver private data in a (type) safe way

---

## Outlook

- just like `Registration`, this could be generalized for all buses / subsystems
  - introduce a generic for the ID type (e.g. `struct pci_device_id`) for `IdArray` and `IdTable`
## Statically allocate driver structures

---

`rust/kernel/driver.rs`

```Rust [48-52|60-69]
// SPDX-License-Identifier: GPL-2.0

//! Generic driver support

use crate::types::Opaque;
use kernel::prelude::*;

/// The [`RegistrationOps`] trait serves as a generic interface for subsystems (e.g., PCI, Platform,
/// Amba, etc.) to provide the corresponding subsystem-specific implementation to register /
/// unregister a driver of the particular type (`RegType`).
///
/// For instance, the PCI subsystem would set `RegType` to `bindings::pci_driver` and call
/// `bindings::__pci_register_driver` from `RegistrationOps::register` and
/// `bindings::pci_unregister_driver` from `RegistrationOps::unregister`.
pub trait RegistrationOps {
    /// The type that holds information about the registration. This is typically a struct defined
    /// by the C portion of the kernel, e.g. `bindings::pci_driver`.
    type RegType: Default;

    /// Registers a driver.
    ///
    /// # Safety
    ///
    /// `reg` must point to valid, initialised, and writable memory. It may be modified by this
    /// function to hold registration state.
    ///
    /// On success, `reg` must remain pinned and valid until the matching call to
    /// [`RegistrationOps::unregister`].
    unsafe fn register(
        reg: *mut Self::RegType,
        name: &'static CStr,
        module: &'static ThisModule,
    ) -> Result;

    /// Unregisters a driver previously registered with [`RegistrationOps::register`].
    ///
    /// # Safety
    ///
    /// `reg` must point to valid writable memory, initialised by a previous successful call to
    /// [`RegistrationOps::register`].
    unsafe fn unregister(reg: *mut Self::RegType);
}

/// Registration structure for a driver.
///
/// The existence of an instance of this structure implies that the corresponding driver is
/// currently registered.
#[pin_data(PinnedDrop)]
pub struct Registration<T: RegistrationOps> {
    #[pin]
    driver: Opaque<T::RegType>,
}

impl<T> Registration<T>
where
    T: RegistrationOps,
{
    /// Register a new driver from `T::RegType`.
    pub fn new(name: &'static CStr, module: &'static ThisModule) -> impl PinInit<Self, Error> {
        try_pin_init!(Self {
            driver <- Opaque::try_ffi_init(|ptr: *mut T::RegType| {
                // SAFETY: `try_ffi_init` guarantees that `ptr` is valid for write.
                unsafe { ptr.write(T::RegType::default()) };

                // SAFETY: `driver` has just been initialized; `T::unregister` is called on
                // `Self::drop`.
                unsafe { T::register(ptr, name, module) }
            }),
        })
    }
}

#[pinned_drop]
impl<T> PinnedDrop for Registration<T>
where
    T: RegistrationOps,
{
    fn drop(self: Pin<&mut Self>) {
        // SAFETY: Only ever called if the `Registration` was created successfully.
        unsafe { T::unregister(self.driver.get()) };
    }
}

// SAFETY: `Registration` has no fields or methods accessible via `&Registration`, so it is safe to
// share references to it with multiple threads as nothing can be done.
unsafe impl<T> Sync for Registration<T> where T: RegistrationOps {}

// SAFETY: Both registration and unregistration are implemented in C and safe to be performed from
// any thread, so `Registration` is `Send`.
unsafe impl<T> Send for Registration<T> where T: RegistrationOps {}
```

---

`samples/rust/rust_pci_driver/mod.rs`

```Rust [17-21|25-27]
// SPDX-License-Identifier: GPL-2.0

//! Rust PCI driver sample.

mod driver;

use kernel::{pci, prelude::*};

module! {
    type: Module,
    name: "rust_pci_driver_sample",
    author: "Danilo Krummrich",
    description: "Rust PCI driver sample",
    license: "GPL",
}

#[pin_data]
struct Module {
    #[pin]
    _reg: kernel::driver::Registration<pci::Adapter<driver::Driver>>,
}

impl kernel::InPlaceModule for Module {
    fn init(name: &'static CStr, module: &'static ThisModule) -> impl PinInit<Self, Error> {
        try_pin_init!(Module {
            _reg <- kernel::driver::Registration::new(name, module),
        })
    }
}
```

---

## Issues resolved

- statically allocate driver structures
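The guarantee that an existing `Registration` implies a registered driver is ordinary RAII: register on construction, unregister in `drop`. The following is a userspace model of that idea under stated assumptions: `MockBus`, `MockDriver`, the thread-local `BUS` list and the `registered()` helper are hypothetical stand-ins for a subsystem such as PCI, `&'static str` replaces `CStr`, and the `RegType` lives in a `Box` to get a stable address. The kernel version instead initializes the structure in place through `try_pin_init!`, which is what allows it to be embedded in the statically allocated module structure.

```rust
use std::cell::RefCell;

thread_local! {
    // Hypothetical stand-in for the bus core's list of registered drivers:
    // (address of the driver structure, driver name).
    static BUS: RefCell<Vec<(usize, &'static str)>> = RefCell::new(Vec::new());
}

/// Mirrors the shape of `RegistrationOps`, in userspace.
trait RegistrationOps {
    type RegType: Default;

    /// # Safety
    /// `reg` must be valid for writes and keep its address until `unregister` is called.
    unsafe fn register(reg: *mut Self::RegType, name: &'static str) -> Result<(), i32>;

    /// # Safety
    /// `reg` must have been passed to a successful `register` call before.
    unsafe fn unregister(reg: *mut Self::RegType);
}

/// Hypothetical driver structure, playing the role of `struct pci_driver`.
#[derive(Default)]
struct MockDriver {
    name: &'static str,
}

struct MockBus;

impl RegistrationOps for MockBus {
    type RegType = MockDriver;

    unsafe fn register(reg: *mut MockDriver, name: &'static str) -> Result<(), i32> {
        if name.is_empty() {
            return Err(-22); // -EINVAL
        }
        // SAFETY: the caller guarantees `reg` is valid for reads and writes.
        let drv = unsafe { &mut *reg };
        drv.name = name;
        // The "bus core" remembers where the driver structure lives.
        BUS.with(|bus| bus.borrow_mut().push((reg as usize, drv.name)));
        Ok(())
    }

    unsafe fn unregister(reg: *mut MockDriver) {
        BUS.with(|bus| bus.borrow_mut().retain(|&(addr, _)| addr != reg as usize));
    }
}

/// While a `Registration` exists, the driver is registered; dropping it unregisters.
struct Registration<T: RegistrationOps> {
    // Boxed so the address handed to the bus stays the same when `Registration` moves.
    reg: Box<T::RegType>,
}

impl<T: RegistrationOps> Registration<T> {
    fn new(name: &'static str) -> Result<Self, i32> {
        let mut reg = Box::new(T::RegType::default());
        // SAFETY: the boxed value keeps its address until it is freed, and `Drop` below
        // unregisters it first. On error, the box is freed without ever being registered.
        unsafe { T::register(&mut *reg, name)? };
        Ok(Self { reg })
    }
}

impl<T: RegistrationOps> Drop for Registration<T> {
    fn drop(&mut self) {
        // SAFETY: a `Registration` is only constructed after `register` succeeded.
        unsafe { T::unregister(&mut *self.reg) };
    }
}

/// Names currently on the hypothetical bus.
fn registered() -> Vec<&'static str> {
    BUS.with(|bus| bus.borrow().iter().map(|&(_, name)| name).collect())
}

fn main() {
    {
        let _reg = Registration::<MockBus>::new("rust_pci_driver_sample").unwrap();
        println!("while alive: {:?}", registered());
    } // `_reg` is dropped here, which unregisters the driver
    println!("after drop: {:?}", registered());
    println!("empty name: {:?}", Registration::<MockBus>::new("").err());
}
```

A failed `register` never produces a `Registration`, so `Drop` can rely on the registration having succeeded; this is the same argument the `// SAFETY` comment in the `PinnedDrop` impl of `rust/kernel/driver.rs` makes.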
## Outlook

- the proposed patch series also implements abstractions that are not part of this discussion, i.e. abstractions for:
  - `struct pci_dev`, such that `probe()` and `remove()` don't need to juggle raw pointers themselves
  - `Io` memory and PCI BARs, such that we can do bounds checks on the mappings (partially at compile time)
  - `Devres`, such that we can easily control access to resources bound to the lifetime of a device / driver binding
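The "partially at compile time" bounds checking of `Io` can be approximated in plain Rust with const generics: when both the size of the mapping and the register offset are known at build time, an out-of-bounds access can be rejected by the compiler, while offsets known only at run time go through checked accessors returning `Option`. This is a sketch under assumptions: `MockIo`, `InBounds`, `read32`, `try_read32` and `try_write32` are hypothetical names, a byte array stands in for the MMIO mapping, little-endian access is assumed, and the kernel abstraction may enforce its compile-time checks through a different mechanism.

```rust
/// Hypothetical stand-in for a mapped I/O region (e.g. a PCI BAR) of `SIZE` bytes;
/// a plain byte array replaces real MMIO.
struct MockIo<const SIZE: usize> {
    mem: [u8; SIZE],
}

/// Compile-time check: evaluating `OK` fails the build if the access does not fit.
struct InBounds<const SIZE: usize, const OFFSET: usize, const WIDTH: usize>;

impl<const SIZE: usize, const OFFSET: usize, const WIDTH: usize> InBounds<SIZE, OFFSET, WIDTH> {
    const OK: () = assert!(OFFSET + WIDTH <= SIZE, "I/O access out of bounds");
}

impl<const SIZE: usize> MockIo<SIZE> {
    fn new() -> Self {
        Self { mem: [0; SIZE] }
    }

    /// Offset known at build time: an out-of-bounds `OFFSET` is rejected by the compiler.
    fn read32<const OFFSET: usize>(&self) -> u32 {
        let () = InBounds::<SIZE, OFFSET, 4>::OK;
        let b = &self.mem[OFFSET..OFFSET + 4];
        u32::from_le_bytes([b[0], b[1], b[2], b[3]])
    }

    /// Offset known only at run time: checked on every access.
    fn try_read32(&self, offset: usize) -> Option<u32> {
        let b = self.mem.get(offset..offset.checked_add(4)?)?;
        Some(u32::from_le_bytes([b[0], b[1], b[2], b[3]]))
    }

    fn try_write32(&mut self, offset: usize, value: u32) -> Option<()> {
        let b = self.mem.get_mut(offset..offset.checked_add(4)?)?;
        b.copy_from_slice(&value.to_le_bytes());
        Some(())
    }
}

fn main() {
    let mut io = MockIo::<16>::new();
    io.try_write32(12, 0xdead_beef).unwrap();
    println!("{:#x}", io.read32::<12>()); // 12 + 4 <= 16: accepted at build time
    println!("{:?}", io.try_read32(14)); // 14 + 4 > 16: rejected at run time
    // io.read32::<14>(); // would fail to build: "I/O access out of bounds"
}
```

The split keeps the common case (fixed register offsets of a device with a known minimum BAR size) free of run-time error handling, and only dynamically computed offsets pay for a check.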